{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e8224915",
   "metadata": {},
   "source": [
    "# 10 - 匹配\n",
    "\n",
    "## 回归到底在做什么？\n",
    "\n",
    "到目前为止，我们已经看到，当我们进行测试与控制比较时，回归在控制附加变量方面做得非常出色。如果我们有独立性，\\\\((Y_0, Y_1)\\perp T | X\\\\)，那么回归可以通过控制 X 来识别 ATE。回归的方式有点神奇。为了对它有一些直观的了解，让我们记住所有变量 X 都是虚拟变量的情况。如果是这种情况，回归会将数据划分为虚拟单元格并计算测试和控制之间的平均差异。这种均值差异使 Xs 保持不变，因为我们是在 X dummy 的固定单元格中进行的。就好像我们在做 \\\\(E[Y|T=1] - E[Y|T=0] | X=x\\\\)，其中 \\\\(x\\\\) 是一个虚拟单元（所有虚拟单元例如设置为 1）。然后回归结合每个单元格中的估计以产生最终的 ATE。这样做的方法是将权重应用到与该组处理的方差成正比的单元格。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45a25139",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from matplotlib import style\n",
    "from matplotlib import pyplot as plt\n",
    "import statsmodels.formula.api as smf\n",
    "\n",
    "import graphviz as gr\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "style.use(\"fivethirtyeight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2116d885",
   "metadata": {},
   "source": [
    "举个例子，假设我试图估计一种药物的效果，我有 6 个男人和 4 个女人。 我的反应变量是住院天数，我希望我的药物可以降低住院天数。 对男性而言，真正的因果效应是-3，因此该药物将住院时间缩短了3 天。 对于女性，它是-2。 更有趣的是，男性更容易受到这种疾病的影响，并且在医院停留的时间更长。 他们也得到了更多的药物。 6 名男性中只有 1 人没有得到药物。 另一方面，女性对这种疾病的抵抗力更强，所以她们在医院的时间更少。 50% 的女性得到了这种药物。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "61e1771c",
   "metadata": {},
   "outputs": [],
   "source": [
    "drug_example = pd.DataFrame(dict(\n",
    "    sex= [\"M\",\"M\",\"M\",\"M\",\"M\",\"M\", \"W\",\"W\",\"W\",\"W\"],\n",
    "    drug=[1,1,1,1,1,0,   1,0,1,0],\n",
    "    days=[5,5,5,5,5,8,  2,4,2,4]\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea8a9d4e",
   "metadata": {},
   "source": [
    "请注意，治疗和对照的简单比较会产生负面影响，即药物似乎不如实际有效。 这是意料之中的，因为我们已经省略了性别混淆因素。 在这种情况下，估计的 ATE 小于真实的 ATE，因为男性服用的药物更多，更容易受到疾病的影响。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d6ec0084",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1.1904761904761898"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drug_example.query(\"drug==1\")[\"days\"].mean() - drug_example.query(\"drug==0\")[\"days\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e81763a",
   "metadata": {},
   "source": [
    "由于男性的真实效果是 -3，女性的真实效果是 -2，ATE 应该是\n",
    "\n",
    "$\n",
    "ATE=\\dfrac{(-3*6) + (-2*4)}{10}=-2.6\n",
    "$\n",
    "\n",
    "该估计是通过 1) 将数据划分为混杂单元格，在这种情况下为男性和女性，2) 估计对每个单元格的影响，以及 3) 将估计值与加权平均值相结合，其中权重是样本量单元格或协变量组。如果我们的数据中男性和女性的大小完全相同，则 ATE 估计值将正好在两组 ATE 的中间，即 -2.5。由于我们的数据集中男性多于女性，因此 ATE 估计值更接近于男性的 ATE。这称为非参数估计，因为它没有假设数据是如何生成的。\n",
    "\n",
    "如果我们使用回归控制性别，我们将添加线性假设。回归还将数据划分为男性和女性，并估计对这两个组的影响。到现在为止还挺好。然而，当谈到组合对每组的影响时，它并没有按样本量来衡量它们。相反，回归使用与该组中治疗的方差成比例的权重。在我们的案例中，男性的治疗差异小于女性，因为对照组中只有一名男性。准确地说，男性的干预变量 `T`的 方差为 \\\\(0.139=1/6*(1 - 1/6)\\\\)，女性样本的方差则为 \\\\(0.25=2/4*(1 - 2/4) \\\\)。因此，在我们的示例中，回归将赋予女性更高的权重，并且 ATE 将更接近于 -2 的女性 ATE。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d2dbccb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "       <td></td>          <th>coef</th>     <th>std err</th>      <th>t</th>      <th>P>|t|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Intercept</th>   <td>    7.5455</td> <td>    0.188</td> <td>   40.093</td> <td> 0.000</td> <td>    7.100</td> <td>    7.990</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>C(sex)[T.W]</th> <td>   -3.3182</td> <td>    0.176</td> <td>  -18.849</td> <td> 0.000</td> <td>   -3.734</td> <td>   -2.902</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>drug</th>        <td>   -2.4545</td> <td>    0.188</td> <td>  -13.042</td> <td> 0.000</td> <td>   -2.900</td> <td>   -2.010</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.table.SimpleTable'>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smf.ols('days ~ drug + C(sex)', data=drug_example).fit().summary().tables[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a618dbc",
   "metadata": {},
   "source": [
    "这个结果对于虚拟变量来说更直观，但是，回归有它自己独特的运作方式，它在估计连续变量效果的同时假设了连续变量不变。同样对于连续变量，ATE 将指向协变量具有更多方差的方向。\n",
    "\n",
    "所以我们已经看到回归有它的特质。它是线性的、参数化的、喜欢高方差特征……这可能是好是坏，具体取决于上下文。因此，重要的是要了解我们可以用来控制混淆因素的其他技术。它们不仅是你手边因果工具中的一个额外工具，而且了解处理混淆的不同方法可以扩展我们对问题的理解。出于这个原因，我现在向您介绍 **子分类估计器（Subclassification Estimator）！**\n",
    "\n",
    "\n",
    "## 子分类估计器\n",
    "\n",
    "![img](./data/img/matching/explain.png)\n",
    "\n",
    "如果我们想要估计一些因果效应，比如工作培训对收入的影响，**并且**该情况下的干预不是随机分配的时候，我们需要注意混淆因素。比如，因为可能只有更有求职动力的人才会参加培训，因此无论培训如何，他们都会获得更高的收入。我们需要估计培训计划在动机水平大致相同的小组以及我们可能存在的任何其他混淆因素中的效果。\n",
    "\n",
    "更一般的情况下，如果我们想要估计一些因果效应，但由于某些变量 `X` 带来的混淆而很难做到时，我们需要做的是在 `X` 相同的小组内进行干预与对照的比较。如果我们具备条件独立 \\\\((Y_0, Y_1)\\perp T | X\\\\)，那么我们可以将 ATE 写成如下。\n",
    "\n",
    "$\n",
    "ATE = \\int(E[Y|X, T=1] - E[Y|X, T=0])dP(x)\n",
    "$\n",
    "\n",
    "这个积分的作用是遍历特征 X 分布的所有空间，计算所有这些微小空间的均值差异，并将所有内容组合到 ATE 中。另一种看待这一点的方法是考虑一组离散的特征。在这种情况下，我们可以说特征 X 对 K 个不同的单元格 \\\\(\\{X_1, X_2, ..., X_k\\}\\\\) 进行计算，我们正在做的是计算每个单元格中的干预效果并结合他们进入ATE。在这种离散情况下，将积分转换为和，我们可以推导出子分类估计量\n",
    "\n",
    "\n",
    "$\n",
    "\\hat{ATE} = \\sum^K_{i=0}(\\bar{Y}_{k1} - \\bar{Y}_{k0}) * \\dfrac{N_k}{N}\n",
    "$\n",
    "\n",
    "其中变量上的横线符号代表被干预样本的表现平均值，\\\\(Y_{k1}\\\\)，和未干预样本的表现平均值，\\\\(Y_{k0}\\\\)，单元格 k 和 \\\\(N_{k}\\\\ ) 是同一单元格中的观察数。如您所见，我们正在计算每个单元格的本地 ATE，并使用加权平均值将它们组合起来，其中权重是单元格的样本大小。在我们上面的医学例子中，这将是第一个估计，它给了我们-2.6。\n",
    "\n",
    "## 匹配估计器\n",
    "\n",
    "![img](./data/img/matching/its-a-match.png)\n",
    "\n",
    "子分类估计器在实践中用得不多（我们很快就会明白为什么，主要是因为维度诅咒这个原因），但它让我们很好地、直观地了解了因果推理估计器应该做什么，以及它应该如何控制混淆因素。这使我们能够探索其他类型的估计器，例如匹配估计器。\n",
    "\n",
    "这个想法非常相似。由于某种混淆因素 X 使得经过干预的和未干预的样本单元最初无法比较，我可以通过**将每个经过干预的单元与类似的未经干预的单元匹配**来做到这一点。这就像我为每个干预单元找到一个未经干预的双胞胎。通过进行这样的比较，干预过的和未经干预的样本再次变得可比较。\n",
    "\n",
    "举个例子，假设我们试图估计一个练习生训练计划对收入的影响。这是练习生的基本情况："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "54a2edbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>unit</th>\n",
       "      <th>trainees</th>\n",
       "      <th>age</th>\n",
       "      <th>earnings</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>28</td>\n",
       "      <td>17700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>34</td>\n",
       "      <td>10200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>29</td>\n",
       "      <td>14400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>20800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>29</td>\n",
       "      <td>6100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>23</td>\n",
       "      <td>28600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>33</td>\n",
       "      <td>21900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>27</td>\n",
       "      <td>28800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>20300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>26</td>\n",
       "      <td>28100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>9400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>1</td>\n",
       "      <td>27</td>\n",
       "      <td>14300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>29</td>\n",
       "      <td>12500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>19700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>10100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>1</td>\n",
       "      <td>43</td>\n",
       "      <td>10700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>1</td>\n",
       "      <td>28</td>\n",
       "      <td>11500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>18</td>\n",
       "      <td>1</td>\n",
       "      <td>27</td>\n",
       "      <td>10700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>19</td>\n",
       "      <td>1</td>\n",
       "      <td>28</td>\n",
       "      <td>16300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    unit  trainees  age  earnings\n",
       "0      1         1   28     17700\n",
       "1      2         1   34     10200\n",
       "2      3         1   29     14400\n",
       "3      4         1   25     20800\n",
       "4      5         1   29      6100\n",
       "5      6         1   23     28600\n",
       "6      7         1   33     21900\n",
       "7      8         1   27     28800\n",
       "8      9         1   31     20300\n",
       "9     10         1   26     28100\n",
       "10    11         1   25      9400\n",
       "11    12         1   27     14300\n",
       "12    13         1   29     12500\n",
       "13    14         1   24     19700\n",
       "14    15         1   25     10100\n",
       "15    16         1   43     10700\n",
       "16    17         1   28     11500\n",
       "17    18         1   27     10700\n",
       "18    19         1   28     16300"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainee = pd.read_csv(\"./data/trainees.csv\")\n",
    "trainee.query(\"trainees==1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc58f369",
   "metadata": {},
   "source": [
    "下面是非练习生的基本情况："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "31f0301d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>unit</th>\n",
       "      <th>trainees</th>\n",
       "      <th>age</th>\n",
       "      <th>earnings</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>43</td>\n",
       "      <td>20900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>31000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>22</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>21000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>23</td>\n",
       "      <td>0</td>\n",
       "      <td>27</td>\n",
       "      <td>9300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>24</td>\n",
       "      <td>0</td>\n",
       "      <td>54</td>\n",
       "      <td>41100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>25</td>\n",
       "      <td>0</td>\n",
       "      <td>48</td>\n",
       "      <td>29800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>26</td>\n",
       "      <td>0</td>\n",
       "      <td>39</td>\n",
       "      <td>42000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>27</td>\n",
       "      <td>0</td>\n",
       "      <td>28</td>\n",
       "      <td>8800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>28</td>\n",
       "      <td>0</td>\n",
       "      <td>24</td>\n",
       "      <td>25500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>29</td>\n",
       "      <td>0</td>\n",
       "      <td>33</td>\n",
       "      <td>15500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>31</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "      <td>400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>32</td>\n",
       "      <td>0</td>\n",
       "      <td>31</td>\n",
       "      <td>26600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>33</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "      <td>16500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>34</td>\n",
       "      <td>0</td>\n",
       "      <td>34</td>\n",
       "      <td>24200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>35</td>\n",
       "      <td>0</td>\n",
       "      <td>25</td>\n",
       "      <td>23300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>36</td>\n",
       "      <td>0</td>\n",
       "      <td>24</td>\n",
       "      <td>9700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>37</td>\n",
       "      <td>0</td>\n",
       "      <td>29</td>\n",
       "      <td>6200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>38</td>\n",
       "      <td>0</td>\n",
       "      <td>35</td>\n",
       "      <td>30200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>39</td>\n",
       "      <td>0</td>\n",
       "      <td>32</td>\n",
       "      <td>17800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>23</td>\n",
       "      <td>9500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>41</td>\n",
       "      <td>0</td>\n",
       "      <td>32</td>\n",
       "      <td>25900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    unit  trainees  age  earnings\n",
       "19    20         0   43     20900\n",
       "20    21         0   50     31000\n",
       "21    22         0   30     21000\n",
       "22    23         0   27      9300\n",
       "23    24         0   54     41100\n",
       "24    25         0   48     29800\n",
       "25    26         0   39     42000\n",
       "26    27         0   28      8800\n",
       "27    28         0   24     25500\n",
       "28    29         0   33     15500\n",
       "29    31         0   26       400\n",
       "30    32         0   31     26600\n",
       "31    33         0   26     16500\n",
       "32    34         0   34     24200\n",
       "33    35         0   25     23300\n",
       "34    36         0   24      9700\n",
       "35    37         0   29      6200\n",
       "36    38         0   35     30200\n",
       "37    39         0   32     17800\n",
       "38    40         0   23      9500\n",
       "39    41         0   32     25900"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainee.query(\"trainees==0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ddde10f",
   "metadata": {},
   "source": [
    "如果我对均值做一个简单比较，我们会发现那些练习生相比非练习生赚的更少。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "02e2852b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-4297.49373433584"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainee.query(\"trainees==1\")[\"earnings\"].mean() - trainee.query(\"trainees==0\")[\"earnings\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47a2b2cc",
   "metadata": {},
   "source": [
    "但是，如果我们看一下上面的表格，我们会注意到练习生比非练习生年轻得多，这表明年龄可能是一个混淆因素。让我们使用年龄匹配来尝试纠正这一点。我们将从接受干预的人那里取出1号单元，并将其与27号单元配对，因为两者都是28岁。对于单元2，我们将它与单元34配对，而单元3则与单元37配对，对于单元4我们将它与单元35配对...当涉及到5号单元时，我们需要从未接受干预的人中找到29岁的人，但那是37号单元，它已经配对了。这其实不是问题，因为我们可以多次使用相同的单元。如果可以匹配的单位超过1个，我们可以在它们之间随机选择。\n",
    "\n",
    "这是前 7 个单元在匹配后的数据集中的样子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3b6230cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>unit_t_1</th>\n",
       "      <th>trainees_t_1</th>\n",
       "      <th>age</th>\n",
       "      <th>earnings_t_1</th>\n",
       "      <th>unit_t_0</th>\n",
       "      <th>trainees_t_0</th>\n",
       "      <th>earnings_t_0</th>\n",
       "      <th>t1_minuts_t0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>28</td>\n",
       "      <td>17700</td>\n",
       "      <td>27</td>\n",
       "      <td>0</td>\n",
       "      <td>8800</td>\n",
       "      <td>8900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>34</td>\n",
       "      <td>10200</td>\n",
       "      <td>34</td>\n",
       "      <td>0</td>\n",
       "      <td>24200</td>\n",
       "      <td>-14000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>29</td>\n",
       "      <td>14400</td>\n",
       "      <td>37</td>\n",
       "      <td>0</td>\n",
       "      <td>6200</td>\n",
       "      <td>8200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>20800</td>\n",
       "      <td>35</td>\n",
       "      <td>0</td>\n",
       "      <td>23300</td>\n",
       "      <td>-2500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>29</td>\n",
       "      <td>6100</td>\n",
       "      <td>37</td>\n",
       "      <td>0</td>\n",
       "      <td>6200</td>\n",
       "      <td>-100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>23</td>\n",
       "      <td>28600</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "      <td>9500</td>\n",
       "      <td>19100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>33</td>\n",
       "      <td>21900</td>\n",
       "      <td>29</td>\n",
       "      <td>0</td>\n",
       "      <td>15500</td>\n",
       "      <td>6400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   unit_t_1  trainees_t_1  age  earnings_t_1  unit_t_0  trainees_t_0  \\\n",
       "0         1             1   28         17700        27             0   \n",
       "1         2             1   34         10200        34             0   \n",
       "2         3             1   29         14400        37             0   \n",
       "3         4             1   25         20800        35             0   \n",
       "4         5             1   29          6100        37             0   \n",
       "5         6             1   23         28600        40             0   \n",
       "6         7             1   33         21900        29             0   \n",
       "\n",
       "   earnings_t_0  t1_minuts_t0  \n",
       "0          8800          8900  \n",
       "1         24200        -14000  \n",
       "2          6200          8200  \n",
       "3         23300         -2500  \n",
       "4          6200          -100  \n",
       "5          9500         19100  \n",
       "6         15500          6400  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make dataset where no one has the same age\n",
    "unique_on_age = (trainee\n",
    "                 .query(\"trainees==0\")\n",
    "                 .drop_duplicates(\"age\"))\n",
    "\n",
    "matches = (trainee\n",
    "           .query(\"trainees==1\")\n",
    "           .merge(unique_on_age, on=\"age\", how=\"left\", suffixes=(\"_t_1\", \"_t_0\"))\n",
    "           .assign(t1_minuts_t0 = lambda d: d[\"earnings_t_1\"] - d[\"earnings_t_0\"]))\n",
    "\n",
    "matches.head(7)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2a6a505",
   "metadata": {},
   "source": [
    "请注意，最后一列的收益差额为已干预单元和与其匹配的未干预单位之间的差异。如果我们取最后一列的平均值，我们得到控制年龄情况下的ATET估计值。请注意，与之前我们使用简单均值差值的估计值相比，该估计值现在显著为正。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "53ca557e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2457.8947368421054"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "matches[\"t1_minuts_t0\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69d64ff8",
   "metadata": {},
   "source": [
    "但这是一个人为设置的例子，只是为了引入匹配这个概念。实际上，我们通常有多个特征，并且单元间也是不能完全可以匹配。在这种情况下，我们必须定义一些接近度的测量值，以比较单元之间的接近程度。一个常见的指标是欧几里得范数 \\\\(||X_i - X_j||\\\\)。 但是，这种差异在特征的大小变化时并不是保持不变。这意味着，与收入等量纲更大的特征相比，在计算此范数时，类似年纪这种以十分之一为单位的特征的重要性要小得多。因此，在应用范数之前，我们需要缩放特征的值，使它们具有大致相同的比例。\n",
    "\n",
    "定义了距离的测度指标后，我们现在可以将匹配定义为寻找要匹配的样本的最近邻居。在数学方面，我们可以通过以下方式编写匹配估计器：\n",
    "\n",
    "\n",
    "$\n",
    "\\hat{ATE} = \\frac{1}{N} \\sum^N_{i=0}(2T_i - 1)\\big(Y_i - Y_{jm}(i)\\big)\n",
    "$\n",
    "\n",
    "其中 \\\\(Y_{jm}(i)\\\\) 是来自与 \\\\(Y_i\\\\) 最相似的另一个干预组的样本。我们这样做\\\\(2T_i - 1\\\\)次，并以两种方式匹配：从干预组匹配对照组样本，以及从对照组匹配干预样本。\n",
    "\n",
    "为了测试这个估计器，让我们考虑一个医学示例。跟上次一样，我们想找到药物对病人恢复时间长短的效果。不幸的是，这种影响被疾病的严重程度、性别以及年龄所混淆。我们有理由相信，病情更严重的患者接受药物治疗的机会更高。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "378ebd64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>severity</th>\n",
       "      <th>medication</th>\n",
       "      <th>recovery</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>35.049134</td>\n",
       "      <td>0.887658</td>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>41.580323</td>\n",
       "      <td>0.899784</td>\n",
       "      <td>1</td>\n",
       "      <td>49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>28.127491</td>\n",
       "      <td>0.486349</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>36.375033</td>\n",
       "      <td>0.323091</td>\n",
       "      <td>0</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>25.091717</td>\n",
       "      <td>0.209006</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   sex        age  severity  medication  recovery\n",
       "0    0  35.049134  0.887658           1        31\n",
       "1    1  41.580323  0.899784           1        49\n",
       "2    1  28.127491  0.486349           0        38\n",
       "3    1  36.375033  0.323091           0        35\n",
       "4    0  25.091717  0.209006           0        15"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "med = pd.read_csv(\"./data/medicine_impact_recovery.csv\")\n",
    "med.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c7b99e",
   "metadata": {},
   "source": [
    "如果我们看一个简单的均值差，\\\\(E[Y|T=1]-E[Y|T=0]\\\\)，我们得到受到治疗的病人平均需要比未接受治疗的病人多16.9天才能恢复。这可能是由于混淆，因为我们不认为药物会对患者造成伤害。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "11656cbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16.895799546498726"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "med.query(\"medication==1\")[\"recovery\"].mean() - med.query(\"medication==0\")[\"recovery\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75837330",
   "metadata": {},
   "source": [
    "为了纠正这个偏差，我们需要使用匹配来控制X。首先，我们一定要记得缩放我们的特征，否则，类似年龄这样的特征在我们计算两个样本点间距离的时候，会比严重性这种特征有更高的重要性。我们可以通过对特征进行归一化的方式来解决这个问题。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ad8943a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>severity</th>\n",
       "      <th>medication</th>\n",
       "      <th>recovery</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.996980</td>\n",
       "      <td>0.280787</td>\n",
       "      <td>1.459800</td>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.002979</td>\n",
       "      <td>0.865375</td>\n",
       "      <td>1.502164</td>\n",
       "      <td>1</td>\n",
       "      <td>49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.002979</td>\n",
       "      <td>-0.338749</td>\n",
       "      <td>0.057796</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.002979</td>\n",
       "      <td>0.399465</td>\n",
       "      <td>-0.512557</td>\n",
       "      <td>0</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.996980</td>\n",
       "      <td>-0.610473</td>\n",
       "      <td>-0.911125</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        sex       age  severity  medication  recovery\n",
       "0 -0.996980  0.280787  1.459800           1        31\n",
       "1  1.002979  0.865375  1.502164           1        49\n",
       "2  1.002979 -0.338749  0.057796           0        38\n",
       "3  1.002979  0.399465 -0.512557           0        35\n",
       "4 -0.996980 -0.610473 -0.911125           0        15"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# scale features\n",
    "X = [\"severity\", \"age\", \"sex\"]\n",
    "y = \"recovery\"\n",
    "\n",
    "med = med.assign(**{f: (med[f] - med[f].mean())/med[f].std() for f in X})\n",
    "med.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "188e4341",
   "metadata": {},
   "source": [
    "现在，到匹配本身。我们将使用来自 [Sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html) 的 K 最近邻算法，而不是编写匹配函数。此算法通过在估计或训练集中查找最近的数据点来进行预测。\n",
    "\n",
    "为了匹配，我们需要其中的2个函数。一个是; `mt0` ，它将存储未干预的样本，并在被要求时在未处理的点中找到匹配项。另一个，`mt1`，将存储被干预的样本，并在需要时在被干预的样本点中找到匹配项。在此拟合步骤之后，我们可以使用这些 KNN 模型进行预测，从而是我们的匹配样本。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7a5aad96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>severity</th>\n",
       "      <th>medication</th>\n",
       "      <th>recovery</th>\n",
       "      <th>match</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.996980</td>\n",
       "      <td>0.280787</td>\n",
       "      <td>1.459800</td>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>39.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.002979</td>\n",
       "      <td>0.865375</td>\n",
       "      <td>1.502164</td>\n",
       "      <td>1</td>\n",
       "      <td>49</td>\n",
       "      <td>52.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>-0.996980</td>\n",
       "      <td>1.495134</td>\n",
       "      <td>1.268540</td>\n",
       "      <td>1</td>\n",
       "      <td>38</td>\n",
       "      <td>46.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1.002979</td>\n",
       "      <td>-0.106534</td>\n",
       "      <td>0.545911</td>\n",
       "      <td>1</td>\n",
       "      <td>34</td>\n",
       "      <td>45.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>-0.996980</td>\n",
       "      <td>0.043034</td>\n",
       "      <td>1.428732</td>\n",
       "      <td>1</td>\n",
       "      <td>30</td>\n",
       "      <td>39.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         sex       age  severity  medication  recovery  match\n",
       "0  -0.996980  0.280787  1.459800           1        31   39.0\n",
       "1   1.002979  0.865375  1.502164           1        49   52.0\n",
       "7  -0.996980  1.495134  1.268540           1        38   46.0\n",
       "10  1.002979 -0.106534  0.545911           1        34   45.0\n",
       "16 -0.996980  0.043034  1.428732           1        30   39.0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "treated = med.query(\"medication==1\")\n",
    "untreated = med.query(\"medication==0\")\n",
    "\n",
    "mt0 = KNeighborsRegressor(n_neighbors=1).fit(untreated[X], untreated[y])\n",
    "mt1 = KNeighborsRegressor(n_neighbors=1).fit(treated[X], treated[y])\n",
    "\n",
    "predicted = pd.concat([\n",
    "    # find matches for the treated looking at the untreated knn model\n",
    "    treated.assign(match=mt0.predict(treated[X])),\n",
    "    \n",
    "    # find matches for the untreated looking at the treated knn model\n",
    "    untreated.assign(match=mt1.predict(untreated[X]))\n",
    "])\n",
    "\n",
    "predicted.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfcfb42b",
   "metadata": {},
   "source": [
    "匹配完成后，我们就可以应用匹配估计器的公式了：\n",
    "$\n",
    "\\hat{ATE} = \\frac{1}{N} \\sum^N_{i=0} (2T_i - 1)\\big(Y_i - Y_{jm}(i)\\big)\n",
    "$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6481511a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.9954"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean((2*predicted[\"medication\"] - 1)*(predicted[\"recovery\"] - predicted[\"match\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7fe2d1a",
   "metadata": {},
   "source": [
    "使用这种匹配，我们可以看到药物的效果不再是增加恢复所需时间。这意味着，控制X后，药物平均将恢复时间减少约1天。这已经是一个巨大的改进，毕竟之前的有偏估计可是预测恢复时间需要增加16.9天。\n",
    "\n",
    "但是，我们仍然可以做得更好。\n",
    "\n",
    "## 匹配偏差\n",
    "\n",
    "\n",
    "事实证明，我们上面设计的匹配估计器还是有偏差的。为了看到这一点，让我们考虑ATE估计器，而不是ATE，只是因为它写起来更简单。其原理也适用于ATE。\n",
    "\n",
    "$\n",
    "\\hat{ATET} = \\frac{1}{N_1}\\sum(Y_i - Y_j(i))\n",
    "$\n",
    "\n",
    "其中 \\\\(N_1\\\\) 是接受治疗的个体数，\\\\(Y_j(i)\\\\) 是未经治疗的单元 i 的匹配。为了检查偏差，我们所做的是希望我们可以应用中心极限定理，以便该定理收敛到平均值为零的正态分布。\n",
    "\n",
    "$\n",
    "\\sqrt{N_1}(\\hat{ATET} - ATET）\n",
    "$\n",
    "\n",
    "但是，这并不总是发生。如果我们定义给定 X 的均值结果，\\\\(\\mu_0(x)=E[Y|X=x， T=0]\\\\)，我们将得到如下结果：（顺便说一句，我省略了证明，因为它不是这里的重点）。\n",
    "\n",
    "$\n",
    "E[\\sqrt{N_1}(\\hat{ATET} - ATET)] = E[\\sqrt{N_1}(\\mu_0(X_i) - \\mu_0(X_j(i)))]\n",
    "$\n",
    "\n",
    "现在，\\\\(\\mu_0(X_i) - \\mu_0(X_j(i))\\\\) 并不是那么容易理解，所以让我们更仔细地看一下。 \\\\(\\mu_0(X_i)\\\\) 是未处理的处理单元的结果 Y 值。因此，它是单元 i 的反事实结果 \\\\(Y_0\\\\)。 \\\\(\\mu_0(X_j(i))\\\\) 是未处理单元 j 的结果，它是单元 i 的匹配项。因此，它也是 \\\\(Y_0\\\\) ，但现在用于单元 j。只有这一次，这是一个事实结果，因为 j 在未处理组中。现在，因为 j 和 i 只是相似，但并不相同，所以这很可能不为零。换句话说，\\\\(X_i \\approx X_j \\\\)。所以，\\\\(Y_{0i} \\approx Y_{0j} \\\\)。\n",
    "\n",
    "随着我们增加样本量，将会有更多的单元来匹配，所以单元 i 和它的匹配 j 之间的差异也会变得更小。但是这种差异会慢慢收敛到零。结果 \\\\(E[\\sqrt{N_1}(\\mu_0(X_i) - \\mu_0(X_j(i)))]\\\\) 可能不会收敛到零，因为 \\\\(\\sqrt{N_1}\\ \\) 增长快于 \\\\((\\mu_0(X_i) - \\mu_0(X_j(i)))\\\\) 减少。\n",
    "\n",
    "当匹配差异很大时，就会出现偏差。幸运的是，我们知道如何纠正它。每个观察都将 \\\\((\\mu_0(X_i) - \\mu_0(X_j(i)))\\\\) 贡献给偏差，所以我们需要做的就是从估计器中的每个匹配比较中减去这个数量。为此，我们可以将 \\\\(\\mu_0(X_j(i))\\\\) 替换为对这个数量 \\\\(\\hat{\\mu_0}(X_j(i))\\\\) 的某种估计，它可以可以通过线性回归等模型获得。这会将 ATET 估计器更新为以下等式\n",
    "\n",
    "\n",
    "$\n",
    "\\hat{ATET} = \\frac{1}{N_1}\\sum \\big((Y_i - Y_{j(i)}) - (\\hat{\\mu_0}(X_i) - \\hat{\\mu_0}(X_{j(i)}))\\big)\n",
    "$\n",
    "\n",
    "\n",
    "where \\\\(\\hat{\\mu_0}(x)\\\\) is some estimative of \\\\(E[Y|X, T=0]\\\\), like a linear regression fitted only on the untreated sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2004709f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# fit the linear regression model to estimate mu_0(x)\n",
    "ols0 = LinearRegression().fit(untreated[X], untreated[y])\n",
    "ols1 = LinearRegression().fit(treated[X], treated[y])\n",
    "\n",
    "# find the units that match to the treated\n",
    "treated_match_index = mt0.kneighbors(treated[X], n_neighbors=1)[1].ravel()\n",
    "\n",
    "# find the units that match to the untreatd\n",
    "untreated_match_index = mt1.kneighbors(untreated[X], n_neighbors=1)[1].ravel()\n",
    "\n",
    "predicted = pd.concat([\n",
    "    (treated\n",
    "     # find the Y match on the other group\n",
    "     .assign(match=mt0.predict(treated[X])) \n",
    "     \n",
    "     # build the bias correction term\n",
    "     .assign(bias_correct=ols0.predict(treated[X]) - ols0.predict(untreated.iloc[treated_match_index][X]))),\n",
    "    (untreated\n",
    "     .assign(match=mt1.predict(untreated[X]))\n",
    "     .assign(bias_correct=ols1.predict(untreated[X]) - ols1.predict(treated.iloc[untreated_match_index][X])))\n",
    "])\n",
    "\n",
    "predicted.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f078538d",
   "metadata": {},
   "source": [
    "一个直接出现的问题是：这不是破坏了匹配这个出发点吗？如果无论如何我都必须运行线性回归，我为什么不一直就使用回归，而不是这个复杂的匹配模型。有这个想法很正常，所以我应该花一些时间来回答它。\n",
    "\n",
    "![img](./data/img/matching/ubiquitous-ols.png)\n",
    "\n",
    "首先，我们拟合的这个线性回归并没有外推干预维度来获得干预效果。相反，它的目的只是为了纠正偏差。这里的线性回归是局部的，从某种意义上说，如果它看起来像未干预的，它不会尝试查看干预后的情况。它没有进行任何推断。这是留给匹配的部分。估计器的核心仍然是匹配组件。我想在这里说明的一点是，OLS方法相对这个估计量本身其实是一个次要考虑的因素。\n",
    "\n",
    "第二点是匹配是一个非参数估计。它不假设线性或任何类型的参数模型。因此，它比线性回归更灵活，并且可以在线性回归不会的情况下工作，即非线性非常强的情况。\n",
    "\n",
    "这是否意味着您应该只使用匹配？嗯，这是一个棘手的问题。 Alberto Abadie 提出了一个理由，认为更应该偏好匹配方法。它更灵活，一旦你有了代码，运行起来也同样简单。我并不完全相信这一点。至少，Abadie 花了很多时间研究和开发估计器（是的，他是帮助匹配算法发展到现在这个阶段的科学家之一），所以他显然对匹配这种方法有个人的偏好。其次，线性回归的简单性在匹配中是看不到的。线性回归比匹配更容易掌握“保持其他一切不变”的偏导数数学。但这只是我的偏好。老实说，这个问题没有明确的答案。无论如何，回到我们的例子。\n",
    "\n",
    "使用偏差校正公式，我得到以下 ATE 估计。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8026cc18",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean((2*predicted[\"medication\"] - 1)*((predicted[\"recovery\"] - predicted[\"match\"])-predicted[\"bias_correct\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a27064f1",
   "metadata": {},
   "source": [
    "当然，我们还需要围绕这个测量放置一个置信区间，但现在数学理论已经足够了。 在实践中，我们可以简单地使用别人的代码并导入一个匹配的估计器。 这是来自python库 [causalinference](https://github.com/laurencium/causalinference) 的一个。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee843a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from causalinference import CausalModel\n",
    "\n",
    "cm = CausalModel(\n",
    "    Y=med[\"recovery\"].values, \n",
    "    D=med[\"medication\"].values, \n",
    "    X=med[[\"severity\", \"age\", \"sex\"]].values\n",
    ")\n",
    "\n",
    "cm.est_via_matching(matches=1, bias_adj=True)\n",
    "\n",
    "print(cm.estimates)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a76c5e8",
   "metadata": {},
   "source": [
    "最后，我们可以自信地说，我们的药物确实可以减少人们在医院的时间。 ATE 估计值比我的略低，所以可能我的代码并不完美，所以这是尽量使用别人的成熟代码，而不是自己从头构建代码的另一个原因。\n",
    "\n",
    "在我们结束这个话题之前，我只是想稍微解释一下匹配偏差的原因。我们看到当单元和它的匹配不太相似时，匹配是有偏差的。但是是什么导致它们如此不同呢？\n",
    "\n",
    "\n",
    "## 维度的诅咒\n",
    "\n",
    "事实证明，答案非常简单直观。很容易找到与一些特征相匹配的人，比如性别。但如果我们添加更多特征，如年龄、收入、出生城市等，找到匹配项就变得越来越难。更一般地说，我们拥有的特征越多，单位与其匹配之间的距离就越大。\n",
    "\n",
    "这不仅仅是伤害匹配估计器的事情。它与我们之前看到的子分类估计器相关联。早期，在那个人为的医学示例中，对于男人和女人，构建子分类估计器非常容易。那是因为我们只有两个牢房：男人和女人。但如果我们有更多会发生什么？假设我们有 2 个连续的特征，比如年龄和收入，我们设法将它们离散成 5 个桶。这将为我们提供 25 个单元格，或 \\\\(5^2\\\\)。如果我们有 10 个协变量，每个协变量有 3 个桶呢？好像不多吧？好吧，这会给我们 59049 个单元格，或 \\\\(3^{10}\\\\)。很容易看出这如何很快就变得不成比例。这是所有数据科学中普遍存在的现象，被称为**维度的诅咒**！！！\n",
    "\n",
    "![img](./data/img/curse-of-dimensionality.jpg)\n",
    "图片来源：https://deepai.org/machine-learning-glossary-and-terms/curse-of-dimensionality\n",
    "\n",
    "尽管它的名字吓人且自命不凡，但这仅意味着填充特征空间所需的数据点数量随着特征或维度的数量呈指数增长。因此，如果需要 X 个数据点来填充例如 3 个特征空间的空间，则需要成倍增加的点来填充 4 个特征空间。\n",
    "\n",
    "在子分类估计器的上下文中，维数灾难意味着如果我们有很多特征，它就会受到影响。许多特征意味着 X 中有多个单元格。如果有多个单元格，其中一些单元格的数据将非常少。其中一些甚至可能只进行了治疗或仅进行了控制，因此无法估计那里的 ATE，这会破坏我们的估计量。在匹配上下文中，这意味着特征空间将非常空间并且单元将彼此相距很远。这将增加匹配之间的距离并导致偏差问题。\n",
    "\n",
    "至于线性回归，它实际上很好地处理了这个问题。它所做的是将所有特征 X 投影到一个单一的 Y 维度中。然后，它对该投影进行处理和控制比较。因此，以某种方式，线性回归执行某种降维来估计 ATE。它相当优雅。\n",
    "\n",
    "大多数因果模型也有一些方法来处理维度灾难。我不会一直重复自己，但在看它们时应该牢记这一点。例如，当我们在下一节中处理倾向得分时，试着看看它是如何解决这个问题的。\n",
    "\n",
    "## 关键思想\n",
    "\n",
    "我们已经开始了解线性回归的作用以及它如何帮助我们识别因果关系。也就是说，我们理解回归可以看作是将数据集划分为单元，计算每个单元中的 ATE，然后将单元的 ATE 组合成整个数据集的单个 ATE。\n",
    "\n",
    "从那里，我们推导出了一个非常通用的带有子分类的因果推理估计器。我们看到该估计器在实践中如何不是很有用，但它为我们提供了一些关于如何解决因果推理估计问题的有趣见解。这让我们有机会讨论匹配估计器。\n",
    "\n",
    "通过查看每个处理单元并找到与其非常相似并且对于未处理单元类似的未处理对来匹配混杂因素的对照。我们看到了如何使用 KNN 算法来实现这个方法，以及如何使用回归来消除它的偏差。最后，我们讨论了匹配和线性回归之间的区别。我们看到了匹配如何是一种非参数估计器，它不像线性回归那样依赖于线性。\n",
    "\n",
    "最后，我们深入研究了高维数据集的问题，并且我们看到了因果推理方法如何受到它的影响。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
